{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exploring the Impact of Evaluation Order on the Wagner-Fischer Algorithm for Levenshtein Edit Distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def algo_v0(s1, s2) -> tuple[int, np.ndarray]:\n",
    "    \"\"\"Compute the Levenshtein distance with the classic Wagner-Fischer algorithm.\n",
    "\n",
    "    The matrix is filled row by row, so every cell is computed after its top,\n",
    "    left, and top-left neighbors.\n",
    "\n",
    "    Args:\n",
    "        s1: First sequence; its prefixes index the rows of the matrix.\n",
    "        s2: Second sequence; its prefixes index the columns of the matrix.\n",
    "\n",
    "    Returns:\n",
    "        A `(distance, matrix)` pair: the edit distance between `s1` and `s2`\n",
    "        as a plain `int`, and the full `(len(s1) + 1) x (len(s2) + 1)` matrix\n",
    "        where `matrix[i, j]` is the distance between `s1[:i]` and `s2[:j]`.\n",
    "    \"\"\"\n",
    "    rows, cols = len(s1) + 1, len(s2) + 1\n",
    "    matrix = np.zeros((rows, cols), dtype=int)\n",
    "\n",
    "    # Initialize the first column and first row of the matrix:\n",
    "    # reaching (or leaving) the empty prefix costs one edit per character.\n",
    "    matrix[:, 0] = np.arange(rows)\n",
    "    matrix[0, :] = np.arange(cols)\n",
    "\n",
    "    # Compute Levenshtein distance\n",
    "    for i in range(1, rows):\n",
    "        for j in range(1, cols):\n",
    "            substitution_cost = int(s1[i - 1] != s2[j - 1])\n",
    "            matrix[i, j] = min(\n",
    "                matrix[i - 1, j] + 1,  # Deletion\n",
    "                matrix[i, j - 1] + 1,  # Insertion\n",
    "                matrix[i - 1, j - 1] + substitution_cost,  # Substitution\n",
    "            )\n",
    "\n",
    "    # Convert the NumPy scalar so the result matches the annotated `int`.\n",
    "    return int(matrix[-1, -1]), matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Accelerating this exact algorithm isn't trivial, as the `matrix[i, j]` value depends on the `matrix[i, j-1]` value computed just before it in the same row.\n",
    "So we can't brute-force accelerate the inner loop.\n",
    "Instead, we can show that we can evaluate the matrix in a different order, and still get the same result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Skew diagonals of a square matrix](https://mathworld.wolfram.com/images/eps-svg/SkewDiagonal_1000.svg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def algo_v1(s1, s2, verbose: bool = False) -> tuple[int, np.ndarray]:\n",
    "    \"\"\"Compute the Levenshtein distance by walking the matrix along its skew diagonals.\n",
    "\n",
    "    Every cell of a skew (anti-)diagonal depends only on cells of the two\n",
    "    preceding skew diagonals, so the cells of one diagonal are independent of\n",
    "    each other. This variant only supports inputs of equal length, which\n",
    "    yields a square matrix.\n",
    "\n",
    "    Args:\n",
    "        s1: First sequence; its prefixes index the rows of the matrix.\n",
    "        s2: Second sequence of the same length; its prefixes index the columns.\n",
    "        verbose: When true, print the coordinates of every evaluated cell.\n",
    "\n",
    "    Returns:\n",
    "        A `(distance, matrix)` pair: the edit distance as a plain `int` and the\n",
    "        fully populated `(len(s1) + 1) x (len(s2) + 1)` matrix.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If `s1` and `s2` differ in length.\n",
    "    \"\"\"\n",
    "    assert len(s1) == len(s2), \"First define an algo for square matrices!\"\n",
    "\n",
    "    # Number of rows and columns in the square matrix.\n",
    "    n = len(s1) + 1\n",
    "\n",
    "    # Pre-fill with a value larger than any reachable distance, so that a cell\n",
    "    # which was never computed cannot be mistaken for a valid result.\n",
    "    unreachable = len(s1) + len(s2) + 1\n",
    "    matrix = np.full((n, n), unreachable, dtype=int)\n",
    "\n",
    "    # Initialize the first column and first row of the matrix\n",
    "    matrix[:, 0] = np.arange(n)\n",
    "    matrix[0, :] = np.arange(n)\n",
    "\n",
    "    skew_diagonals_count = 2 * n - 1\n",
    "    # Compute Levenshtein distance. Skew diagonals 0 and 1 lie entirely on the\n",
    "    # pre-populated borders, so we start from the third one.\n",
    "    for skew_diagonal_idx in range(2, skew_diagonals_count):\n",
    "        in_top_left_triangle = skew_diagonal_idx < n\n",
    "        skew_diagonal_length = (skew_diagonal_idx + 1) if in_top_left_triangle else (2 * n - skew_diagonal_idx - 1)\n",
    "        for offset_within_skew_diagonal in range(skew_diagonal_length):\n",
    "            if in_top_left_triangle:\n",
    "                # If we haven't passed the main skew diagonal yet,\n",
    "                # then we have to skip the first and the last operation,\n",
    "                # as those are already pre-populated and form the first column\n",
    "                # and the first row of the Levenshtein matrix respectively.\n",
    "                if offset_within_skew_diagonal == 0 or offset_within_skew_diagonal + 1 == skew_diagonal_length:\n",
    "                    continue\n",
    "                i = skew_diagonal_idx - offset_within_skew_diagonal\n",
    "                j = offset_within_skew_diagonal\n",
    "                if verbose:\n",
    "                    print(f\"top left triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n",
    "            else:\n",
    "                # Past the main skew diagonal every entry is an interior cell;\n",
    "                # the diagonal starts in the last row and climbs to the last column.\n",
    "                i = n - offset_within_skew_diagonal - 1\n",
    "                j = skew_diagonal_idx - n + offset_within_skew_diagonal + 1\n",
    "                if verbose:\n",
    "                    print(f\"bottom right triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n",
    "            substitution_cost = int(s1[i - 1] != s2[j - 1])\n",
    "            matrix[i, j] = min(\n",
    "                matrix[i - 1, j] + 1,  # Deletion\n",
    "                matrix[i, j - 1] + 1,  # Insertion\n",
    "                matrix[i - 1, j - 1] + substitution_cost,  # Substitution\n",
    "            )\n",
    "\n",
    "    # Convert the NumPy scalar so the result matches the annotated `int`.\n",
    "    return int(matrix[-1, -1]), matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's generate some random strings and make sure we produce the right result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use a dedicated, seeded generator so that every run exercises the same strings\n",
    "# and a failure can be reproduced.\n",
    "rng = random.Random(42)\n",
    "for trial in range(10):\n",
    "    s1 = ''.join(rng.choices(\"ab\", k=50))\n",
    "    s2 = ''.join(rng.choices(\"ab\", k=50))\n",
    "    d0, m0 = algo_v0(s1, s2)\n",
    "    d1, m1 = algo_v1(s1, s2)\n",
    "    assert d0 == d1, f\"distance mismatch on trial {trial} for {s1!r} vs {s2!r}: {d0} != {d1}\"\n",
    "    # Both traversal orders must fill the whole matrix identically, not just its corner.\n",
    "    assert np.array_equal(m0, m1), f\"matrix mismatch on trial {trial} for {s1!r} vs {s2!r}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Going further, we can avoid storing the whole matrix: computing a skew diagonal only requires the two skew diagonals right before it.\n",
    "For a matrix with N rows and N columns, the longer of those two never exceeds N entries, and the other one always differs from it in length by exactly one, so it holds at most N-1 entries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('listen',\n",
       " 'silent',\n",
       " 'distance = 4',\n",
       " array([[0, 1, 2, 3, 4, 5, 6],\n",
       "        [1, 1, 2, 2, 3, 4, 5],\n",
       "        [2, 2, 1, 2, 3, 4, 5],\n",
       "        [3, 2, 2, 2, 3, 4, 5],\n",
       "        [4, 3, 3, 3, 3, 4, 4],\n",
       "        [5, 4, 4, 4, 3, 4, 5],\n",
       "        [6, 5, 5, 5, 4, 3, 4]]))"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Two anagrams of equal length give a small square matrix that is easy to inspect.\n",
    "# For a larger experiment, both strings can instead be generated randomly,\n",
    "# e.g. with ''.join(random.choices(\"abcd\", k=100)).\n",
    "s1, s2 = \"listen\", \"silent\"\n",
    "\n",
    "# The row-by-row result is the reference that the diagonal traversal is checked against.\n",
    "distance, baseline = algo_v0(s1, s2)\n",
    "s1, s2, f\"{distance = }\", baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([0, 0, 0, 0, 0, 0, 0], dtype=uint64),\n",
       " array([1, 1, 0, 0, 0, 0, 0], dtype=uint64),\n",
       " array([0, 0, 0, 0, 0, 0, 0], dtype=uint64))"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "assert len(s1) == len(s2), \"First define an algo for square matrices!\"\n",
    "# The square Levenshtein matrix has `n` rows and `n` columns.\n",
    "n = len(s1) + 1\n",
    "\n",
    "# Rather than the full matrix, keep three buffers of `n` entries each: the skew\n",
    "# diagonal being computed and the two skew diagonals right before it.\n",
    "# As an illustration, picture the 5x5 Levenshtein matrix of two 4-letter words:\n",
    "#         B C D E << s2 characters: BCDE\n",
    "#     + ---------\n",
    "#     | a b c d e\n",
    "#   F | f g h i j\n",
    "#   K | k l m n o\n",
    "#   P | p q r s t\n",
    "#   U | u v w x y\n",
    "#   ^\n",
    "#   ^ s1 characters: FKPU\n",
    "# At the moment the main skew diagonal is being computed:\n",
    "#   - `following` will contain [u, q, m, i, e]\n",
    "#   - `current` will contain [p, l, h, d]\n",
    "#   - `previous` will contain [k, g, c]\n",
    "previous, current, following = (np.zeros(n, dtype=np.uint) for _ in range(3))\n",
    "\n",
    "# Seed the first two skew diagonals:\n",
    "#   - `previous` holds [a], which is zero;\n",
    "#   - `current` holds [f, b], which are both one.\n",
    "previous[0] = 0\n",
    "current[:2] = 1\n",
    "previous, current, following"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To feel safer, while designing our alternative traversal algorithm, let's define an extraction function, that will get the values of a certain skewed diagonal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_skewed_diagonal(matrix: np.ndarray, index: int):\n",
    "    \"\"\"Extract one skew (anti-)diagonal of a 2D array.\n",
    "\n",
    "    Skew diagonals are numbered from the top-left corner, so diagonal `index`\n",
    "    contains the cells whose row and column numbers sum up to `index`. The\n",
    "    values are returned starting from the bottom-left end of the diagonal.\n",
    "    \"\"\"\n",
    "    # Mirroring the columns turns skew diagonals into ordinary diagonals.\n",
    "    mirrored = np.fliplr(matrix)\n",
    "    column_count = matrix.shape[1]\n",
    "    top_down = np.diag(mirrored, k=column_count - index - 1)\n",
    "    # `np.diag` walks from the top row downwards; reverse it to start at the bottom.\n",
    "    return top_down[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "matrix = np.array([[1, 2, 3],\n",
    "                   [4, 5, 6],\n",
    "                   [7, 8, 9]])\n",
    "\n",
    "# Every skew diagonal of the 3x3 sample, listed from its bottom-left end.\n",
    "expected_skewed_diagonals = {\n",
    "    0: [1],\n",
    "    1: [4, 2],\n",
    "    2: [7, 5, 3],\n",
    "    3: [8, 6],\n",
    "    4: [9],\n",
    "}\n",
    "\n",
    "# `assert_array_equal` also fails on a length mismatch, whereas `np.all(a == b)`\n",
    "# would broadcast a single-element diagonal against a longer expectation.\n",
    "for index, expected in expected_skewed_diagonals.items():\n",
    "    np.testing.assert_array_equal(get_skewed_diagonal(matrix, index), expected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([5, 3, 2, 2, 3, 5, 0], dtype=uint64),\n",
       " array([6, 4, 3, 2, 3, 4, 6], dtype=uint64),\n",
       " array([6, 4, 3, 2, 3, 4, 6], dtype=uint64))"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Grow through the top-left triangle: each skew diagonal is one entry longer than\n",
    "# the one before it, and its two ends lie on the first column and the first row.\n",
    "following_skew_diagonal_idx = 2\n",
    "while following_skew_diagonal_idx < n:\n",
    "    following_skew_diagonal_length = following_skew_diagonal_idx + 1\n",
    "    inner_count = following_skew_diagonal_length - 2  # entries strictly between the two borders\n",
    "    inner = slice(1, inner_count + 1)\n",
    "\n",
    "    # Character pairs met along the diagonal: `s1` is walked backwards, `s2` forwards.\n",
    "    mismatches = [a != b for a, b in zip(reversed(s1[:inner_count]), s2[:inner_count])]\n",
    "    via_substitution = previous[:inner_count] + mismatches\n",
    "\n",
    "    # An insertion or a deletion extends one of the two neighbors on the `current` diagonal.\n",
    "    via_gap = np.minimum(current[inner] + 1, current[:inner_count] + 1)\n",
    "    following[inner] = np.minimum(via_gap, via_substitution)\n",
    "\n",
    "    # Both border cells of this diagonal are equal to the diagonal index.\n",
    "    following[0] = following_skew_diagonal_idx\n",
    "    following[inner_count + 1] = following_skew_diagonal_idx\n",
    "    assert np.all(following[:inner_count + 2] == get_skewed_diagonal(baseline, following_skew_diagonal_idx))\n",
    "\n",
    "    # Slide the window of diagonals forward by one.\n",
    "    previous[:] = current\n",
    "    current[:] = following\n",
    "    following_skew_diagonal_idx += 1\n",
    "\n",
    "previous, current, following # Log the state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a longer diagonal. From now onwards, the diagonals will be shrinking. Instead of adding a value equal to the skewed diagonal index on either end, we will be cropping those values out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([5, 4, 5, 5, 5, 6, 0], dtype=uint64),\n",
       " array([4, 5, 4, 5, 5, 5, 6], dtype=uint64),\n",
       " array([4, 5, 4, 5, 5, 5, 6], dtype=uint64))"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shrink through the bottom-right triangle: there are no border cells anymore,\n",
    "# and each skew diagonal is one entry shorter than the one before it.\n",
    "while following_skew_diagonal_idx < 2 * n - 1:\n",
    "    following_skew_diagonal_length = 2 * n - 1 - following_skew_diagonal_idx\n",
    "    kept = slice(0, following_skew_diagonal_length)\n",
    "    shifted = slice(1, following_skew_diagonal_length + 1)\n",
    "\n",
    "    # Character pairs met along the diagonal: `s1` is walked backwards from its\n",
    "    # last character, `s2` forwards from the offset where this diagonal starts.\n",
    "    s2_offset = following_skew_diagonal_idx - n\n",
    "    mismatches = [a != b for a, b in zip(reversed(s1), s2[s2_offset:])]\n",
    "    via_substitution = previous[kept] + mismatches\n",
    "\n",
    "    # An insertion or a deletion extends one of the two neighbors on the `current` diagonal.\n",
    "    via_gap = np.minimum(current[kept] + 1, current[shifted] + 1)\n",
    "    following[kept] = np.minimum(via_gap, via_substitution)\n",
    "    assert np.all(following[kept] == get_skewed_diagonal(baseline, following_skew_diagonal_idx)), f\"\\n{following[kept]} not equal to \\n{get_skewed_diagonal(baseline, following_skew_diagonal_idx)}\"\n",
    "\n",
    "    # Slide the window forward, dropping the first entry of the older diagonal.\n",
    "    previous[kept] = current[shifted]\n",
    "    current[kept] = following[kept]\n",
    "    following_skew_diagonal_idx += 1\n",
    "\n",
    "previous, current, following # Log the state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The last skew diagonal is a single cell, the bottom-right corner of the matrix,\n",
    "# so its first entry has to agree with the distance computed row by row.\n",
    "assert following[0] == distance, f\"{distance = } != {following[0] = }\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
